查看原文
其他

【强基固本】深度学习入门与Pytorch|4.2 变分自编码器(VAE)介绍、推导及代码

“强基固本,行稳致远”,科学研究离不开理论基础,人工智能学科更是需要数学、物理、神经科学等基础学科提供有力支撑,为了紧扣时代脉搏,我们推出“强基固本”专栏,讲解AI领域的基础知识,为你的科研学习提供助力,夯实理论基础,提升原始创新能力,敬请关注。

来源:知乎—aHiiLn

地址:https://zhuanlan.zhihu.com/p/433162159


01

基于内容的自编码器的局限性
上一节的自编码器中,我们使用手写数字数据集(mnist)训练自编码器。在已经训练好的自编码器中,对于每一个手写数字我们会产生相应的编码,当我们对解码器输入相应的编码时往往能够很好地还原当时的手写数字,然而当我们对解码器输入一个训练集中未出现的编码时,可能会发现输出的内容居然是噪声,这说明和和mnist完全没有关系。从这个角度来说,上一节提到的自编码器不算是一种生成模型,因为他并没有产生新的内容。
举个简单的例子来理解。比如我们把⚪、 和口都编码成一个实数,即只有一维的编码:
对于训练集,解码器能无损失地重建形状。如果我们对解码器输入一个训练集中没有出现过的编码,解码器会解码出什么呢?
这就说明了这个自动编码器不能真正生成新的内容。如果自动编码器足够强大,那么它能将任何N个初始训练数据放到实轴上并对他们进行无损地重建。这时自动编码器的高自由度会使编码和解码不会丢失任何信息(尽管编码空间的维数很低),这也意味着严重的过拟合,因此编码空间的某些点(正如上图里实数轴上的紫点)被解码成无意义的内容。
仔细想想,自动编码器被单独训练以尽可能少损失地进行编码和解码,并没有对编码空间的组织结构进行强制要求,所以网络会利用任何过拟合的可能性来尽可能完成任务。除非我们对其进行正则化。


02

VAE的定义
为了使自动编码器能够真正生成新的内容,我们要确保编码空间有足够的规律性。获得规律性的一种方案时在训练过程中显式地引入正则化。所以我们引入变分-自动编码器(VAE):VAE被定义为一种自编码器,其训练被正则化以避免过拟合并确保编码空间具有良好的规律性以支持生成过程。
VAE也和标准的自动编码器一样由编码器和解码器组成,训练使编码解码数据和初始数据之间的重建误差最小化。不同的是,为了正则化,我们不将输入数据编码成单个点,而是将输入数据编码成编码空间上的分布。如果说自动编码器通过训练数据学习到的是某个确定的函数,那么VAE希望能够基于训练数据学习到参数的概率分布。
训练如下:
1. 输入数据被编码成编码空间上的分布。
2. 从编码空间中采样出一个点
3. 对采样点进行解码,计算重建误差
4. 重建误差通过网络反向传播
编码分布被选为正态分布,此外,在训练 VAE 时最小化的损失函数由一个“重构项”和一个“正则化项”组成。它倾向于通过使编码器返回的分布接近标准正态分布来规范潜在空间的组织。该正则化项表示为返回分布与标准高斯分布之间的Kulback-Leibler 散度(KL散度),KL散度也叫相对熵,其理论意义在于度量两个概率分布之间的差异程度,KL散度越高说明两个分布之间的差异程度越大,如果两个分布相同则KL散度为0。
最后,编码空间的规律性主要通过两个属性表达:连续性(编码空间中相近的两个点一旦解码不应给出两个完全不同的内容)和完整性(对于选定的分布,从编码空间中采样得到的点在解码后应给出“有意义”的内容)。
然而将输入编码为分布并不能确保连续性和完整性,可能会返回具有微小方差的分布或者具有均值差别很大的分布,为此我们必须对协方差矩阵和编码器返回的分布的均值进行正则化。在实践中我们通过强制要求协方差矩阵接近正态分布而正则化,这样编码分布彼此之间不会相距太远。
通过正则化,我们就可以获得连续性和完整性,并在编码空间中编码的信息上创建“梯度”。

03

VAE的数学推导
用x表示我们数据的变量,并假设x是未直接观察到的潜在变量z(编码表示)生成的。对每个数据点,假设生成步骤:
1. 从先验分布p(z)中采样一个编码z
2. 数据x从条件似然分布p(x|z)中采样

流程图

在这样的概率模型中,重新定义编码器和解码器并考虑他们的概率版本。概率编码器由p(z|x)定义,其描述了给定解码变量的编码变量;概率解码器由p(x|z)定义,描述了给定编码变量的解码变量。编码空间z遵循先验分布p(z)。通过贝叶斯定理,可以建立三者的联系:
现在假设p(z)为标准高斯分布,p(x|z)是均值由z的确定性函数f定义、协方差为正常数c乘单位矩阵I定义的高斯分布。f属于F族,所以:
到此处,理论上我们可以通过贝叶斯定理计算p(z|x),但实际上因为分母上的积分,这种计算往往较难处理,所以需要使用变分推理来近似p(z|x)。
变分推理的思路是设置一个参数化的分布族并在这个族中寻找目标分布的最佳近似值。该族中最好的元素是最小化给定的近似误差测量值(大部分情况为KL散度),并通过该族的参数的梯度下降找到最佳近似值。
这里,使用高斯分布    来近p(z|x),其均值与协方差由参数x的两个函数g和h定义。g和h分别属于G和H族,因此:
现在需要通过优化g和f的参数来找到最佳近似值,以最小化近似值和p(z|x)的差距,即:
在倒数第二个方程,我们可以观察到在最大化“观察“的可能性(第一项的预期对数似然的最大化)和保持接近先验分布(第二项中    和p(z)的KL散度最小化)。这种权衡表达了我们对数据的信心和我们对先验的信心之间的平衡。
我们假设函数f已知且固定时,我们可以通过上面的推导近似后验p(z|x)。然而在实践中,定义解码器的函数f是未知的,需要选择。对于F中的任何函数f,我们都可以得到p(z|x)的最佳近似值,表示为    。但我们希望能够更加高效:对于给定的输入x,当我们从分布    中采样并从分布p(x|z)中采样    时,我们希望最大化    的概率。所以我们找到最优的    ,使:
   由f确定并可以由前面的推导获得。整理一下,我们要找到最佳    ,使得:
我们可以在这个目标函数中识别上一节给出的VAE直观描述中引入的元素:x和f(z)之间的重构误差以及    和p(z)之间的 KL 散度给出的正则化项)(这是一个标准的高斯)。我们还可以注意到控制前两项之间平衡的常数c。c越高,我们假设模型中概率解码器在f(z)附近的高方差越大,我们越重视正则化项(如果c低,则相反)。


04

VAE引入神经网络
上一节最后我们建立了依赖于函数f、g、h的概率模型,并使用变分推理表达了优化问题,得到的    和    就是此模型的编码-解码方案。将f、g、h表达成神经网络,那么F、G、H就对应了网络构架定义的函数族,只要优化这些网络上的参数即可。实际上,g和h并不是由两个完全独立的网络定义的,而是共享一部分权重,所以:
它定义了    的协方差阵,因此h(x)是一个方阵。然而,为了简化计算并减少参数数量,我们额外假设我们的p(z|x)近似值    是具有对角协方差矩阵的多维高斯分布(变量独立假设)。在这个假设下,h(x)只是协方差矩阵的对角元素的向量,并且具有与g(x)相同的大小。然而,我们通过这种方式减少了我们考虑用于变分推理的分布族,因此,获得的p(z|x)的近似值可能不太准确。
采样后得到z,通过解码器f得到    。
连接两个过程就是VAE了。
最后我们要注意编码器返回的分布中采样的形式。采样过程中我们必须使得误差能够通过网络反向传播。这里使用重新参数化技巧使梯度下降成为可能。尽管随机采样发生在架构的中途,并且包括使用以下事实:如果 z 是遵循具有均值 g(x) 的高斯分布的随机变量,并且协方差 那么它可以表示为:   
最后,以这种方式获得的变分自动编码器架构的目标函数由上一小节的最后一个方程给出,其中理论期望被或多或少准确的蒙特卡罗近似所取代,大多数情况下,该近似包含一个单抽。因此,考虑到这种近似并表示    ,我们恢复了上一节中直观推导出的损失函数,由重构项、正则化项和定义这两项相对权重的常数组成。


05

VAE用于mnist的代码
载入包和数据准备
import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimfrom torchvision import datasets, transformsfrom torch.autograd import Variablefrom torchvision.utils import save_imagebs = 100# MNIST Datasettrain_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=False)# Data Loader (Input Pipeline)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)
定义VAE的类
class VAE(nn.Module): def __init__(self, x_dim, h_dim1, h_dim2, z_dim): super(VAE, self).__init__()
# encoder part self.fc1 = nn.Linear(x_dim, h_dim1) self.fc2 = nn.Linear(h_dim1, h_dim2) self.fc31 = nn.Linear(h_dim2, z_dim) self.fc32 = nn.Linear(h_dim2, z_dim) # decoder part self.fc4 = nn.Linear(z_dim, h_dim2) self.fc5 = nn.Linear(h_dim2, h_dim1) self.fc6 = nn.Linear(h_dim1, x_dim)
def encoder(self, x): h = F.relu(self.fc1(x)) h = F.relu(self.fc2(h)) return self.fc31(h), self.fc32(h) # mu, log_var
def sampling(self, mu, log_var): std = torch.exp(0.5*log_var) eps = torch.randn_like(std) return eps.mul(std).add_(mu) # return z sample
def decoder(self, z): h = F.relu(self.fc4(z)) h = F.relu(self.fc5(h)) return F.sigmoid(self.fc6(h))
def forward(self, x): mu, log_var = self.encoder(x.view(-1, 784)) z = self.sampling(mu, log_var) return self.decoder(z), mu, log_var
# build modelvae = VAE(x_dim=784, h_dim1= 512, h_dim2=256, z_dim=2)if torch.cuda.is_available():    vae.cuda()
优化器和损失函数
optimizer = optim.Adam(vae.parameters())# return reconstruction error + KL divergence lossesdef loss_function(recon_x, x, mu, log_var): BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum') KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) return BCE + KLD
定义训练函数和测试函数
def train(epoch): vae.train() train_loss = 0 for batch_idx, (data, _) in enumerate(train_loader): data = data.cuda() optimizer.zero_grad()
recon_batch, mu, log_var = vae(data) loss = loss_function(recon_batch, data, mu, log_var)
loss.backward() train_loss += loss.item() optimizer.step()
if batch_idx % 100 == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item() / len(data))) print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))def test(): vae.eval() test_loss= 0 with torch.no_grad(): for data, _ in test_loader: data = data.cuda() recon, mu, log_var = vae(data)
# sum up batch loss test_loss += loss_function(recon, data, mu, log_var).item()
test_loss /= len(test_loader.dataset) print('====> Test set loss: {:.4f}'.format(test_loss))
训练
for epoch in range(1, 51): train(epoch) test()

参考:

https://github.com/lyeoni/pytorch-mnist-VAE

https://towardsdatascience.com/

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。


“强基固本”历史文章


更多强基固本专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存